import numpy as np
import torch
import time
import nvdiffrast.torch as dr
from PIL import Image

from CRM.util.utils import get_tri

class CRMSampler:
    """Pipeline helpers for CRM: two diffusion stages plus mesh extraction.

    ``stage1_sample`` conditions on a single pixel image and generates
    multi-view images; ``stage2_sample`` refines those views; ``generate3d``
    decodes the rendered RGB / canonical-coordinate-map strips into a
    textured mesh.
    """

    @classmethod
    def process_pixel_img(
        cls,
        pixel_img: Image.Image,
        bg_color: str = "#7F7F7F"
    ) -> Image.Image:
        """Pad the image to a 1:1 square and flatten it onto a fixed background.

        The alpha channel of ``pixel_img`` (if present) acts as the mask when
        compositing over the background color, so transparent regions end up
        filled with ``bg_color``.

        Args:
            pixel_img: Source image.
            bg_color: Fill color for the padding and the background
                (any PIL color specification).

        Returns:
            A square RGB image with side ``max(width, height)``.
        """
        width, height = pixel_img.size
        side = max(width, height)

        # Center the source on a square RGBA canvas pre-filled with bg_color.
        squared = Image.new("RGBA", (side, side), bg_color)
        squared.paste(pixel_img, ((side - width) // 2, (side - height) // 2))

        # Composite over an opaque layer so transparent pixels take bg_color.
        background = Image.new("RGBA", squared.size, bg_color)
        return Image.alpha_composite(background, squared).convert("RGB")

    @classmethod
    def stage1_sample(
        cls,
        stage1_sampler,
        pixel_img: Image.Image,
        prompt="3D assets",
        neg_texts="uniform low no texture ugly, boring, bad anatomy, blurry, pixelated,  obscure, unnatural colors, poor lighting, dull, and unclear.",
        seed=0,
        scale=5,
        step=50,
        ddim_eta=0.0
    ):
        """Stage 1: condition on a single pixel image and generate multi-view
        pixel images (based on the v2pp config).

        Args:
            stage1_sampler: Configured sampler wrapper exposing ``model``,
                ``sampler``, ``i2i`` and the related attributes read below.
            pixel_img: Preprocessed input image (see ``process_pixel_img``).
            prompt: Positive text prompt.
            neg_texts: Negative prompt used for the unconditional embedding.
            seed: Sampling seed (stored on the sampler).
            scale: Classifier-free guidance scale.
            step: Number of diffusion steps.
            ddim_eta: DDIM eta parameter.

        Returns:
            Generated views, shape (N, H, W, 3) with values in [0, 255];
            the trailing reference view is removed before returning.
        """
        stage1_sampler.seed = seed

        # Unconditional (negative) conditioning for classifier-free guidance.
        uc = stage1_sampler.model.get_learned_conditioning([neg_texts]).to(stage1_sampler.device)
        stage1_images = stage1_sampler.i2i(
            stage1_sampler.model,
            stage1_sampler.size,
            prompt,
            uc=uc,
            sampler=stage1_sampler.sampler,
            ip=pixel_img,
            step=step,
            scale=scale,
            batch_size=stage1_sampler.batch_size,
            ddim_eta=ddim_eta,
            dtype=stage1_sampler.dtype,
            device=stage1_sampler.device,
            camera=stage1_sampler.camera,
            num_frames=stage1_sampler.num_frames,
            pixel_control=(stage1_sampler.mode == "pixel"),
            transform=stage1_sampler.image_transform,
            offset_noise=stage1_sampler.offset_noise,
        )

        # The last image is the reference view, not a novel view — drop it.
        return stage1_images[:-1]

    @classmethod
    def stage2_sample(
        cls,
        stage2_sampler,
        pixel_img: Image.Image,
        stage1_images,
        prompt="3D assets",
        neg_texts="uniform low no texture ugly, boring, bad anatomy, blurry, pixelated,  obscure, unnatural colors, poor lighting, dull, and unclear.",
        seed=0,
        scale=5,
        step=50
    ):
        """Stage 2: condition on the multi-view images generated by stage 1
        and generate the final images (based on the stage2-test config).

        Args:
            stage2_sampler: Configured sampler wrapper exposing ``model``,
                ``sampler``, ``i2iStage2`` and the related attributes.
            pixel_img: The original (preprocessed) input image.
            stage1_images: Torch tensor of stage-1 views, (N, H, W, 3).
            prompt: Positive text prompt.
            neg_texts: Negative prompt used for the unconditional embedding.
            seed: Sampling seed (stored on the sampler).
            scale: Classifier-free guidance scale.
            step: Number of diffusion steps.

        Returns:
            Final images, shape (N, H, W, 3) with values in [0, 1].
        """
        stage2_sampler.seed = seed

        # Convert the torch image batch to a list of PIL images.
        # NOTE(review): stage1_sample documents its output as [0, 255], yet
        # this scales by 255 before the uint8 cast — confirm the actual range.
        stage1_arrays = list((255.0 * stage1_images.cpu().numpy()).astype(np.uint8))
        stage1_images = [Image.fromarray(img) for img in stage1_arrays]

        # Unconditional (negative) conditioning for classifier-free guidance.
        uc = stage2_sampler.model.get_learned_conditioning([neg_texts]).to(stage2_sampler.device)
        stage2_images = stage2_sampler.i2iStage2(
            stage2_sampler.model,
            stage2_sampler.size,
            prompt,
            uc=uc,
            sampler=stage2_sampler.sampler,
            pixel_images=stage1_images,
            ip=pixel_img,
            step=step,
            scale=scale,
            batch_size=stage2_sampler.batch_size,
            ddim_eta=0.0,
            dtype=stage2_sampler.dtype,
            device=stage2_sampler.device,
            camera=stage2_sampler.camera,
            num_frames=stage2_sampler.num_frames,
            pixel_control=(stage2_sampler.mode == "pixel"),
            transform=stage2_sampler.image_transform,
            offset_noise=stage2_sampler.offset_noise,
        )

        return stage2_images

    @classmethod
    def generate3d(cls, crm_model, rgb, ccm, device):
        """Decode rendered multi-view RGB and CCM strips into a textured mesh.

        Args:
            crm_model: CRM reconstruction model providing ``unet2``,
                ``decode``, ``get_mesh_wt_uv`` and (when ``denoising`` is set)
                a ``scheduler`` with ``add_noise``.
            rgb: numpy array holding a horizontal strip of 6 RGB views —
                assumes shape (H, 256*6, 3) with 256-wide tiles; TODO confirm.
            ccm: numpy array with the same layout holding canonical coordinate
                maps; its channels are reversed (BGR -> RGB) below.
            device: torch device for reconstruction (CUDA is required by the
                nvdiffrast rasterizer context created here).

        Returns:
            The textured mesh produced by ``crm_model.get_mesh_wt_uv``.
        """
        color = torch.from_numpy(rgb).permute(2, 0, 1)           # [C, H, W*6]
        # Reverse the coordinate-map channel order (stored as BGR).
        xyz = torch.from_numpy(ccm[:, :, (2, 1, 0)]).permute(2, 0, 1)

        def get_imgs(strip):
            # Split a [C, H, 256*6] strip into a [6, C, H, 256] stack,
            # moving the sixth tile to the front.
            tiles = [strip[:, :, 256 * 5:256 * 6]]
            tiles += [strip[:, :, 256 * i:256 * (i + 1)] for i in range(5)]
            return torch.stack(tiles, dim=0)

        # Compute the color view stack once and reuse it (the original
        # called get_imgs(color) twice with identical results).
        color_views = get_imgs(color)                            # [6, C, H, W]
        xyz_views = get_imgs(xyz)

        # [1, 6, H, W, C] color views for the decoder's tri-view conditioning.
        triplane_color = color_views.permute(0, 2, 3, 1).unsqueeze(0).to(device)

        color_tri = get_tri(color_views, dim=0, blender=True, scale=1).unsqueeze(0)
        xyz_tri = get_tri(xyz_views, dim=0, blender=True, scale=1, fix=True).unsqueeze(0)
        triplane = torch.cat([color_tri, xyz_tri], dim=1).to(device)

        # 3D reconstruction (inference only).
        crm_model.eval()
        glctx = dr.RasterizeCudaContext()

        if crm_model.denoising:
            # Perturb the triplane at a fixed timestep (t = 20) and run one
            # denoising pass through the UNet. torch.full replaces the
            # original randint(20, 21, ...), which always produced 20.
            tnew = torch.full([triplane.shape[0]], 20, dtype=torch.long, device=triplane.device)
            noise_new = torch.randn_like(triplane) * 0.5 + 0.5
            triplane = crm_model.scheduler.add_noise(triplane, noise_new, tnew)
            start_time = time.time()
            with torch.no_grad():
                triplane_feature2 = crm_model.unet2(triplane, tnew)
            elapsed_time = time.time() - start_time
            print(f"Running unet takes {elapsed_time}s")
        else:
            # Inference only: wrap in no_grad to match the denoising branch
            # and avoid building an autograd graph.
            with torch.no_grad():
                triplane_feature2 = crm_model.unet2(triplane)

        with torch.no_grad():
            data_config = {
                'resolution': [1024, 1024],
                "triview_color": triplane_color.to(device),
            }

            verts, faces = crm_model.decode(data_config, triplane_feature2)

            data_config['verts'] = verts[0]
            data_config['faces'] = faces

        # Lazy import: kiui is only needed for this cleanup step.
        from kiui.mesh_utils import clean_mesh
        verts, faces = clean_mesh(
            data_config['verts'].squeeze().cpu().numpy().astype(np.float32),
            data_config['faces'].squeeze().cpu().numpy().astype(np.int32),
            repair=False, remesh=False, remesh_size=0.005,
        )
        # Use the requested device instead of hard-coding cuda:0 via .cuda().
        data_config['verts'] = torch.from_numpy(verts).to(device).contiguous()
        data_config['faces'] = torch.from_numpy(faces).to(device).contiguous()

        start_time = time.time()

        mesh = crm_model.get_mesh_wt_uv(glctx, data_config, device, res=(1024, 1024), tri_fea_2=triplane_feature2)

        elapsed_time = time.time() - start_time
        print(f"Extracting mesh & texture takes {elapsed_time}s")

        return mesh